# -*- coding: utf-8 -*-
"""
ALAAM 驱动机制识别（two-path 正交 + 自适应阈值 + 稳健兜底 + 精简入模 + 2x2 GoF）
- 保证：spatial_lag_misfit_t1_z、corridor_betweenness_t1_z 保留
- 强化：two_path_open_t1_ortho_z（对 logdeg/corridor/spatial_lag 正交，逐年z）
- 若 two_path_open_t1_ortho_z p>=阈值，自适应改用 two_path_topQ4_ortho_t1（同年Q4虚拟）
- 省级聚类稳健 SE；新增 PR 曲线 & Gains 曲线；GoF 图 2x2；GoF 指标新增 AUPRC 与 KS
- 精简入模：首次全量拟合后，剔除极不显著变量，再在“精简RHS”上完成 Step1/Step2

特别说明：
- 按你的要求：输入面板路径已改为  D:/Fragmentation of MPA governance/mpa_network_misfit/pythonProject1/.venv/alaam_nodes_panel.csv
- 不再加入 spatial_lag_topQ4_t1（ADD_SPATIAL_TOPQ4=False）
"""

# ========= 固定路径 & 无头绘图 =========
from pathlib import Path
import os
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
import statsmodels.api as sm
from math import erf

# 强制使用无头后端，保证能保存 PNG（即使没有显示器）
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
MATPLOTLIB_OK = True

def _resolve_panel_path():
    # 候选：相对路径 → /mnt/data → 你的两条 Windows 路径
    cands = [
        Path("outputs_alaam/alaam/alaam_nodes_panel.csv"),
        Path("./alaam_nodes_panel.csv"),
        Path("/mnt/data/outputs_alaam/alaam/alaam_nodes_panel.csv"),
        Path(r"D:\Fragmentation of MPA governance\mpa_network_misfit\pythonProject1\.venv\outputs_alaam\alaam\alaam_nodes_panel.csv"),
        Path(r"D:\下载\outputs_alaam\alaam\alaam_nodes_panel.csv"),
    ]
    for p in cands:
        if p.exists():
            return p.resolve()
    raise FileNotFoundError(
        "找不到 alaam_nodes_panel.csv，请检查网络错配脚本的输出，或把路径加入 _resolve_panel_path() 的候选列表。"
    )

# Resolve the input panel once at import time; raises FileNotFoundError if
# no candidate file exists.
PANEL_PATH = _resolve_panel_path()
# Outputs go to the panel's own directory so everything is archived together
# (the GoF figure is written here as well).
OUT_DIR = PANEL_PATH.parent
OUT_DIR.mkdir(parents=True, exist_ok=True)

print(f"[ALAAM] Using panel: {PANEL_PATH}")
print(f"[ALAAM] Outputs to: {OUT_DIR}")


# ========= Configuration =========
# NOTE(review): Y_COL and KEEP_YEARS are assigned again further down in this
# file (in the "I/O & tools" section); the later assignment is the one in effect.
Y_COL = 'Y_t'
CLUSTER_BY = 'province'
N_CALIB_BINS = 8
RANDOM_SEED = 42
KEEP_YEARS = None  # None means no year filtering; only a list triggers filtering


# two-path handling
WINSOR_TWOPATH_Q = 0.99            # per-year winsorization at the 1st/99th percentiles
TP_BASES_FOR_ORTHO = [
    'logdeg_t1',                 # raw (log) degree (redundancy with triangles is said to be handled when the main panel is built)
    'corridor_betweenness_t1_z', # corridor betweenness, z-scored
    'spatial_lag_misfit_t1_z'    # spatial-lag misfit, z-scored
]
TP_SWITCH_PVAL = 0.15              # adaptive threshold: if p >= 0.15, switch to the Q4 dummy (tunable)

# Other regressors (prefer the *_z versions)
# Note: this version no longer adds spatial_lag_topQ4_t1
BASE_CONT_Z = [
    'logdeg_t1_z',
    'spatial_lag_misfit_t1_z',
    'z_prov_count_looi_t1_z',
    'nonNat_share_t1_z',
    'corridor_betweenness_t1_z'
]
BINARY_COLS = ['high_misfit_t1']
ADD_SPATIAL_TOPQ4 = False          # disabled

# Pruning: p-value threshold applied after the first full fit (higher is more
# lenient; p >= 0.50 is treated here as "extremely insignificant")
PRUNE_PVAL = 0.50
# Structural (endogenous) controls that must always be kept (guarantees at
# least one endogenous structural control)
STRUCTURAL_MIN_KEEP = ['logdeg_t1_z']
# Always kept: key exogenous drivers and core estimands (two_path is handled
# separately in the workflow)
FORCE_KEEP = ['spatial_lag_misfit_t1_z', 'corridor_betweenness_t1_z']

# ========= I/O & 工具 =========

# NOTE(review): both names are already defined in the configuration section
# above; these later assignments take effect (KEEP_YEARS = None there is
# overridden by the list below). Consider keeping a single definition.
Y_COL = 'Y_t'          # dependent-variable column name (must match the panel)
KEEP_YEARS = [2020, 2025]  # keep only 2020/2025; set to None to disable filtering

def load_panel(panel_path: Path | None = None) -> pd.DataFrame:
    """Read the panel CSV, normalise column names and optionally filter years.

    Args:
        panel_path: CSV location; falls back to the module-level PANEL_PATH
            when None.

    Returns:
        The panel with Year/SeaRegion/Province renamed to lower-case names and
        the target column cast to int. Rows are restricted to KEEP_YEARS when
        that is set and at least one of those years is present; if the
        filtered target has no variation the unfiltered panel is returned.

    Raises:
        FileNotFoundError: the CSV does not exist.
        KeyError: the target column Y_COL is missing.
    """
    source = Path(panel_path) if panel_path is not None else PANEL_PATH
    if not source.exists():
        raise FileNotFoundError(f'找不到面板文件：{source}')
    print('[info] 读取面板：', source.resolve())
    df = pd.read_csv(source)

    # The network script emits capitalised names; map them to the lower-case
    # names used throughout this file.
    aliases = (('Year', 'year'), ('SeaRegion', 'sea_region'), ('Province', 'province'))
    rename = {old: new for old, new in aliases if old in df.columns}
    df = df.rename(columns=rename)

    filter_requested = 'year' in df.columns and bool(KEEP_YEARS)
    if filter_requested:
        available = df['year'].unique().tolist()
        years_keep = [yr for yr in KEEP_YEARS if yr in available]
        if years_keep:
            df = df[df['year'].isin(years_keep)].copy()
            print(f'[info] 已按 KEEP_YEARS 过滤 year ∈ {years_keep}，剩余 {len(df)} 行')

    if Y_COL not in df.columns:
        raise KeyError(f'面板缺少目标列 {Y_COL}')
    df[Y_COL] = df[Y_COL].astype(int)

    # A target without variation cannot be modelled: fall back to the
    # complete, unfiltered panel.
    if filter_requested and df[Y_COL].nunique() < 2:
        print('[auto] 发现 Y_t 无变异，自动取消年份过滤以恢复样本。')
        df = pd.read_csv(source).rename(columns=rename)
        df[Y_COL] = df[Y_COL].astype(int)

    return df

# Call only after the function has been defined.
# NOTE(review): this runs at import time; main() calls load_panel(PANEL_PATH)
# again, so the panel is read twice — confirm whether this module-level `df`
# is still needed.
df = load_panel()



def winsorize_by_year(df, col, q=0.99, year_col='year'):
    """Clip ``df[col]`` to its [1-q, q] quantile range, separately per year.

    Args:
        df: input frame (not modified).
        col: column to winsorize.
        q: upper quantile; the lower bound is the ``1 - q`` quantile. ``None``
            disables clipping.
        year_col: grouping column; when absent, quantiles are computed on the
            whole column.

    Returns:
        A float ``pd.Series`` aligned to ``df.index`` in every case (an
        all-NaN Series when ``col`` is missing). Rows whose year is NaN are
        not part of any group and keep their original value.
    """
    if col not in df.columns:
        return pd.Series(np.nan, index=df.index)

    out = df[col].astype(float).copy()
    if q is None:
        return out

    if year_col not in df.columns:
        lo, hi = out.quantile(1 - q), out.quantile(q)
        # Wrap the clipped values so this branch returns an index-aligned
        # Series like every other branch (it used to return a bare ndarray).
        return pd.Series(np.clip(out.values, lo, hi), index=out.index, name=out.name)

    for _, grp in df.groupby(year_col):
        lo, hi = grp[col].quantile(1 - q), grp[col].quantile(q)
        out.loc[grp.index] = np.clip(grp[col].values.astype(float), lo, hi)
    return out

def z_by_year(series, years):
    """Standardise ``series`` (population SD, ddof=0) within each year.

    Args:
        series: values to standardise.
        years: year label per row, or ``None`` to standardise over the whole
            series at once.

    Returns:
        A Series of z-scores. A group whose SD is zero or undefined is set to
        0.0 throughout; rows not matched by any year label stay NaN.
    """
    def _standardise(values):
        # Returns None when the spread is degenerate so callers can fill 0.0.
        values = values.astype(float)
        centre = values.mean()
        spread = values.std(ddof=0)
        if spread == 0 or np.isnan(spread):
            return None
        return (values - centre) / spread

    if years is None:
        scored = _standardise(series)
        if scored is None:
            return pd.Series(0.0, index=series.index)
        return scored

    result = pd.Series(np.nan, index=series.index)
    for label in np.unique(years):
        in_year = (years == label)
        scored = _standardise(series[in_year])
        result.loc[in_year] = 0.0 if scored is None else scored
    return result

def _ols_residuals(frame, target, bases):
    """Return OLS residuals of ``target`` on a constant plus ``bases``.

    Only bases present in ``frame`` are used and rows with any missing value
    are dropped. Returns ``None`` when the regression cannot be run: target
    column missing, no base column present, or fewer than 5 complete rows.
    """
    present = [b for b in bases if b in frame.columns]
    if target not in frame.columns or not present:
        return None
    sub = frame[present + [target]].dropna()
    if len(sub) < 5:
        return None
    design = np.column_stack(
        [np.ones(len(sub))] + [sub[b].astype(float).values for b in present]
    )
    observed = sub[target].astype(float).values
    beta = np.linalg.lstsq(design, observed, rcond=None)[0]
    return pd.Series(observed - design @ beta, index=sub.index)


def residualize_by_year(df, target, bases, add_z=True, name_prefix='tp_ortho'):
    """Orthogonalise ``target`` against ``bases`` (plus intercept), year by year.

    Args:
        df: input frame (not modified).
        target: column to residualise.
        bases: columns to regress on; those absent from ``df`` are ignored.
        add_z: also return the residuals standardised via ``z_by_year``.
        name_prefix: prefix for the names of the returned Series.

    Returns:
        ``(res, z)``: ``res`` holds the residuals aligned to ``df.index`` (NaN
        where no regression could be run); ``z`` is the standardised version,
        or ``None`` when ``add_z`` is False. Without a ``year`` column one
        pooled regression is used instead of per-year regressions.

    A missing ``target`` column now yields all-NaN residuals; previously it
    raised KeyError whenever two or more base columns were present.
    """
    res = pd.Series(np.nan, index=df.index, name=f'{name_prefix}_res')
    has_year = 'year' in df.columns

    # Without a year column the whole frame is one group.
    groups = (grp for _, grp in df.groupby('year')) if has_year else (df,)
    for grp in groups:
        fitted = _ols_residuals(grp, target, bases)
        if fitted is not None:
            res.loc[fitted.index] = fitted.values

    if not add_z:
        return res, None
    z = z_by_year(res, df['year'] if has_year else None)
    z.name = f'{name_prefix}_z'
    return res, z

def add_topQ4_dummy(df, base_col, out_col):
    """Add a 0/1 column flagging rows at or above the 75th percentile.

    The percentile is computed within each year when a ``year`` column
    exists, otherwise over the whole column. ``df`` is modified in place and
    also returned. When ``base_col`` is missing, ``out_col`` is filled with NaN.
    """
    if base_col in df.columns:
        if 'year' not in df.columns:
            cutoff = df[base_col].quantile(0.75)
        else:
            cutoff = df.groupby('year')[base_col].transform(
                lambda grp: grp.quantile(0.75)
            )
        is_top = df[base_col] >= cutoff
        df[out_col] = is_top.astype(int)
    else:
        df[out_col] = np.nan
    return df

# ========= 稳健协方差与兜底包装 =========
def _wrap_with_cov(res, cov):
    return type('Wrap', (), {
        'params': getattr(res, 'params', pd.Series([np.nan], index=['Intercept'])),
        'bse':    (np.sqrt(np.diag(cov)) if cov is not None else getattr(res, 'bse', pd.Series([np.nan], index=['Intercept']))),
        'predict': getattr(res, 'predict', lambda X=None: np.full(getattr(res, 'nobs', 0), np.nan)),
        'llf':     getattr(res, 'llf', np.nan),
        'llnull':  getattr(res, 'llnull', np.nan),
    })()

def _as_robust(res, cluster_series=None):
    try:
        if cluster_series is not None:
            return res.get_robustcov_results(cov_type='cluster', groups=cluster_series)
        else:
            return res.get_robustcov_results(cov_type='HC1')
    except Exception:
        pass
    try:
        if cluster_series is not None:
            return res._get_robustcov_results(cov_type='cluster', groups=cluster_series)
        else:
            return res._get_robustcov_results(cov_type='HC1')
    except Exception:
        pass
    try:
        from statsmodels.stats.sandwich_covariance import cov_cluster, cov_hc1
        cov = cov_cluster(res, cluster_series) if cluster_series is not None else cov_hc1(res)
        return _wrap_with_cov(res, cov)
    except Exception:
        try:
            cov = res.cov_params()
        except Exception:
            cov = None
        return _wrap_with_cov(res, cov)

def _constant_result(d, base_prob=None):
    """Build an intercept-only pseudo-result predicting one constant probability.

    Args:
        d: sample frame; its length fixes the length of the predictions.
        base_prob: probability to predict. When None, the mean of the target
            column is used (0.5 if the target is missing or empty).

    Returns:
        An object with ``params`` (the logit of the probability), ``bse``
        (NaN), ``predict``, ``llf`` and ``llnull`` (both NaN). The probability
        is clipped to [1e-6, 1 - 1e-6] before taking the logit.
    """
    if Y_COL in d.columns:
        observed = d[Y_COL].astype(int).values
    else:
        observed = np.array([], dtype=int)

    if base_prob is not None:
        raw_prob = base_prob
    elif observed.size:
        raw_prob = observed.mean()
    else:
        raw_prob = 0.5
    prob = float(np.clip(raw_prob, 1e-6, 1-1e-6))

    log_odds = np.log(prob/(1-prob))

    def _predict(X=None):
        return np.full(len(d), prob, dtype=float)

    attributes = {
        'params': pd.Series([log_odds], index=['Intercept']),
        'bse': pd.Series([np.nan], index=['Intercept']),
        'predict': _predict,
        'llf': np.nan,
        'llnull': np.nan,
    }
    return type('Wrap', (), attributes)()

# ========= 拟合器 =========
def fit_logit(df, formula):
    """Fit a binary-outcome model for ``formula`` with layered fallbacks.

    Steps:
      * rows with missing values in the target, the cluster column (if
        present) or any plain RHS term are dropped;
      * RHS terms without variation are removed, and year / sea-region fixed
        effects are added only when those columns have more than one level
        (``C(...)`` terms in the incoming formula are ignored and rebuilt);
      * estimation is tried as logit -> regularised logit -> binomial GLM ->
        constant model, each wrapped with robust standard errors.

    Returns:
        ``(d, result, formula_used)`` where ``d`` is the estimation sample.
    """
    # Plain RHS terms only: categorical wrappers and the intercept are rebuilt below.
    raw_terms = []
    for piece in formula.split('~')[1].split('+'):
        term = piece.strip()
        if term.startswith('C(') or term == '1':
            continue
        raw_terms.append(term)

    required = [Y_COL]
    if CLUSTER_BY in df.columns:  # required only when the column exists
        required.append(CLUSTER_BY)
    required.extend(raw_terms)
    d = df.dropna(subset=[c for c in required if c in df.columns]).copy()

    if d.empty or d[Y_COL].nunique()<2:
        print('[warn] 样本为空或 Y 无变异：使用常数模型。')
        return d, _constant_result(d), f"{Y_COL} ~ 1"

    # Drop RHS terms that have no variation in the estimation sample.
    keep = [t for t in raw_terms if (t in d.columns and d[t].nunique(dropna=True)>1)]
    fe = []
    if 'year' in d.columns and d['year'].nunique()>1:
        fe.append('C(year)')
    if 'sea_region' in d.columns and d['sea_region'].nunique()>1:
        fe.append('C(sea_region)')
    rhs = keep + fe
    f = f"{Y_COL} ~ " + (' + '.join(rhs) if keep or fe else '1')

    def _robustify(fitted):
        # Shared post-processing for every estimator below.
        groups = d[CLUSTER_BY] if CLUSTER_BY in d.columns else None
        wrapped = _as_robust(fitted, groups)
        if wrapped is None:
            wrapped = _wrap_with_cov(fitted, getattr(fitted, 'cov_params', lambda: None)())
        return wrapped

    # 1) Plain logit
    try:
        plain = smf.logit(f, data=d).fit(disp=False, maxiter=500)
        return d, _robustify(plain), f
    except Exception as e1:
        print('[warn] Logit 失败，改 L2 正则：', str(e1))

    # 2) Regularised logit
    # NOTE(review): the log message says "L2" but the call uses method='l1';
    # confirm which penalty is actually applied by this statsmodels call.
    try:
        model = smf.logit(f, data=d)
        penalised = model.fit_regularized(method='l1', alpha=1.0, L1_wt=0.0, maxiter=1000)
        penalised.predict = lambda X=None: model.predict(penalised.params, exog=None)
        return d, _robustify(penalised), f
    except Exception as e2:
        print('[warn] 正则 Logit 失败，改 GLM：', str(e2))

    # 3) Binomial GLM, then the constant model as the last resort
    try:
        glm_fit = smf.glm(f, data=d, family=sm.families.Binomial()).fit()
        return d, _robustify(glm_fit), f
    except Exception as e3:
        print('[warn] GLM 失败，回退常数模型：', str(e3))
        return d, _constant_result(d), f

# ========= 系数表 & GoF =========
def extract_coefs(res, label):
    """Tabulate coefficients, standard errors, z, p-values and odds ratios.

    Args:
        res: object exposing ``params`` and, optionally, ``bse``.
        label: value written to the ``model`` column.

    Returns:
        One row per term with two-sided normal p-values and 95% Wald
        intervals on the odds-ratio scale. A zero standard error gives a NaN
        z; NaN z-values are treated as 0 when computing p (so p = 1.0). If
        anything fails, a single all-NaN ``Intercept`` row is returned.
    """
    def _std_normal_cdf(values):
        values = np.asarray(values, dtype=float)
        return 0.5*(1 + np.array([erf(v/np.sqrt(2.0)) for v in values]))

    try:
        coef = pd.Series(getattr(res, 'params', pd.Series([np.nan], index=['Intercept'])))
        raw_se = getattr(res, 'bse', None)
        if raw_se is not None:
            std_err = pd.Series(raw_se, index=coef.index)
        else:
            std_err = pd.Series(np.nan, index=coef.index)

        z_score = coef / std_err.replace(0, np.nan)
        abs_z = np.abs(z_score.fillna(0.0).values)
        p_value = 2*(1-_std_normal_cdf(abs_z))

        lower = coef-1.96*std_err
        upper = coef+1.96*std_err
        table = {
            'model': label,
            'term': coef.index,
            'coef': coef.values,
            'std_err': std_err.values,
            'z': z_score.values,
            'p_value': p_value,
            'OR': np.exp(coef).values,
            'OR_2.5%': np.exp(lower).values,
            'OR_97.5%': np.exp(upper).values,
        }
        return pd.DataFrame(table)
    except Exception:
        blank = {name: [np.nan] for name in
                 ('coef', 'std_err', 'z', 'p_value', 'OR', 'OR_2.5%', 'OR_97.5%')}
        return pd.DataFrame({'model': [label], 'term': ['Intercept'], **blank})

def _roc(y, s):
    order = np.argsort(-s); y=y[order]; s=s[order]
    P=(y==1).sum(); N=(y==0).sum()
    if P==0 or N==0: return np.array([0,1]), np.array([0,1])
    tp=fp=0; T=[0.0]; F=[0.0]; last=None
    for yi,si in zip(y,s):
        if last is None or si!=last: T.append(tp/P); F.append(fp/N); last=si
        if yi==1: tp+=1
        else: fp+=1
    T.append(1.0); F.append(1.0)
    return np.array(F), np.array(T)

def _auc(y, s):
    order = np.argsort(s); y=y[order]
    n1=(y==1).sum(); n0=(y==0).sum()
    if n1==0 or n0==0: return np.nan
    ranks=np.arange(1,len(y)+1); R1=ranks[y==1].sum()
    return (R1 - n1*(n1+1)/2)/(n1*n0)

def _calib(y, p, bins=8):
    q=np.quantile(p, np.linspace(0,1,bins+1)); q[0]=0; q[-1]=1
    idx=np.digitize(p, q[1:-1], right=True)
    xm=[]; ym=[]; n=[]
    for b in range(bins):
        m=(idx==b)
        if m.sum()==0: xm.append(np.nan); ym.append(np.nan); n.append(0)
        else: xm.append(p[m].mean()); ym.append(y[m].mean()); n.append(int(m.sum()))
    return np.array(xm), np.array(ym), np.array(n)

def _precision_recall(y, s):
    # thresholds from high to low
    order = np.argsort(-s)
    y = y[order]; s = s[order]
    P = (y==1).sum()
    if P==0:
        return np.array([0.0, 1.0]), np.array([1.0, 0.0])
    tp=0; fp=0
    precisions=[]; recalls=[]
    last=None
    for yi,si in zip(y,s):
        if last is None or si != last:
            if tp+fp>0:
                precisions.append(tp/(tp+fp))
                recalls.append(tp/P)
            last=si
        if yi==1: tp+=1
        else: fp+=1
    if tp+fp>0:
        precisions.append(tp/(tp+fp))
        recalls.append(tp/P)
    if len(recalls) == 0:
        return np.array([0.0, 1.0]), np.array([1.0, 0.0])
    return np.array(recalls), np.array(precisions)

def _auprc(y, s):
    r, p = _precision_recall(y, s)
    if len(r)<2: return np.nan
    order = np.argsort(r)
    r = r[order]; p = p[order]
    return float(np.trapz(p, r))

def _gains_curve(y, s):
    order = np.argsort(-s)
    y = y[order]
    cum_pos = np.cumsum(y==1)
    P = (y==1).sum()
    n = len(y)
    if n==0 or P==0:
        return np.array([0,1]), np.array([0,1])
    frac = np.arange(1, n+1)/n
    gains = cum_pos / P
    frac = np.concatenate([[0.0], frac])
    gains = np.concatenate([[0.0], gains])
    return frac, gains

def _ks_stat(y, s):
    """Kolmogorov-Smirnov statistic: the largest |TPR - FPR| gap on the ROC curve."""
    false_pos_rate, true_pos_rate = _roc(y, s)
    gap = np.abs(true_pos_rate - false_pos_rate)
    return float(np.max(gap))

def gof_plot(df, res, out_png, label='pooled'):
    """Compute goodness-of-fit metrics, append them to a CSV and draw a 2x2 figure.

    Args:
        df: estimation sample containing the target column ``Y_COL``.
        res: fitted-result-like object; ``res.predict()`` supplies in-sample
            probabilities, ``llf`` / ``llnull`` feed McFadden's R².
        out_png: destination of the figure.
        label: model label written to the metrics file and the figure title.

    Side effects:
        Appends one row to ``OUT_DIR/alaam_gof_metrics.csv`` (header written
        only when the file is created) and saves the figure to ``out_png``.

    Predictions are coerced to a flat float array and must have one value per
    row of ``df``; otherwise the sample base rate is used, as is already done
    when ``predict`` raises. This keeps the metric helpers, which index
    arrays positionally, working when ``predict`` returns a labelled Series
    or an array of the wrong length.
    """
    y = df[Y_COL].values.astype(int)
    try:
        if hasattr(res, 'predict'):
            p = np.asarray(res.predict(), dtype=float).ravel()
        else:
            p = np.repeat(y.mean(), len(y))
    except Exception:
        p = np.repeat(y.mean(), len(y))
    if p.shape != y.shape:
        print('[warn] 预测概率长度与样本不一致，改用样本基准率。')
        p = np.repeat(y.mean(), len(y))

    auc = _auc(y, p); brier=float(((y-p)**2).mean())
    try:
        mcf = 1 - float(getattr(res, 'llf', np.nan))/float(getattr(res, 'llnull', np.nan))
    except Exception:
        mcf = np.nan

    # Additional metrics: AUPRC and KS
    auprc = _auprc(y, p)
    ks = _ks_stat(y, p)

    # Write the metrics row; append without header when the file already exists.
    cols = ['model','AUC','AUPRC','KS','Brier','McFadden_R2','n','pos_rate']
    vals = [{'model':label,'AUC':auc,'AUPRC':auprc,'KS':ks,'Brier':brier,'McFadden_R2':mcf,
             'n':len(y),'pos_rate':float(y.mean())}]
    fp = OUT_DIR/'alaam_gof_metrics.csv'
    if fp.exists():
        pd.DataFrame(vals)[cols].to_csv(fp, mode='a', header=False, index=False, encoding='utf-8-sig')
    else:
        pd.DataFrame(vals)[cols].to_csv(fp, index=False, encoding='utf-8-sig')

    if not MATPLOTLIB_OK:
        print('[warn] matplotlib 不可用，跳过 GoF 图。'); return

    # 2x2 layout: ROC | Calibration / PR | Gains
    fig, axes = plt.subplots(2,2, figsize=(11.2,8.8))
    try:
        axROC = axes[0,0]; axCAL=axes[0,1]; axPR=axes[1,0]; axGAIN=axes[1,1]

        # ROC
        fpr,tpr=_roc(y,p); axROC.plot(fpr,tpr,linewidth=2); axROC.plot([0,1],[0,1],'--')
        axROC.set_title('ROC'); axROC.set_xlabel('False Positive Rate'); axROC.set_ylabel('True Positive Rate')

        # Calibration (marker size reflects the bin count)
        x,yc,n=_calib(y,p,bins=N_CALIB_BINS); axCAL.plot([0,1],[0,1],'--')
        axCAL.scatter(x,yc,s=np.maximum(n,1)); axCAL.set_title('Calibration')
        axCAL.set_xlabel('Mean predicted probability (bin)'); axCAL.set_ylabel('Observed fraction positive (bin)')

        # Precision-recall
        r,prec=_precision_recall(y,p)
        if len(r)>1:
            axPR.plot(r,prec,linewidth=2)
        axPR.set_xlim(0,1); axPR.set_ylim(0,1)
        axPR.set_title('Precision–Recall'); axPR.set_xlabel('Recall'); axPR.set_ylabel('Precision')

        # Cumulative gains
        frac,gain = _gains_curve(y,p)
        axGAIN.plot(frac,gain,linewidth=2); axGAIN.plot([0,1],[0,1],'--')
        axGAIN.set_title('Cumulative Gains'); axGAIN.set_xlabel('Sample fraction'); axGAIN.set_ylabel('Cumulative positives captured')

        fig.subplots_adjust(bottom=0.10, wspace=0.30, hspace=0.35)
        txt=f'AUC={auc:.3f}  AUPRC={auprc:.3f}  KS={ks:.3f}  Brier={brier:.3f}  McFadden R²={mcf:.3f}  n={len(y)}  Pos={y.mean():.3f}'
        fig.text(0.5,0.03,txt,ha='center',va='center',fontsize=11)
        fig.suptitle('ALAAM Goodness-of-Fit — '+label, y=0.995, fontsize=12)
        fig.savefig(out_png, dpi=220)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

def preflight_panel_report(df: pd.DataFrame):
    """Run a health check on the panel before estimation.

    Prints the distribution of ``Y_t``, writes per-column missing/unique
    counts for the core columns to ``alaam_preflight_report.csv`` in
    ``OUT_DIR``, and traces how many rows survive as each key column is
    successively required to be non-null.
    """
    report_path = OUT_DIR / 'alaam_preflight_report.csv'
    core_columns = (
        'Y_t', 'province', 'year', 'sea_region',
        'two_path_open_t1', 'two_path_open_t1_ortho_z',
        'logdeg_t1_z', 'spatial_lag_misfit_t1_z', 'corridor_betweenness_t1_z',
        'z_prov_count_looi_t1_z', 'nonNat_share_t1_z', 'high_misfit_t1',
    )
    n_total = len(df)

    # Target overview
    if 'Y_t' not in df.columns:
        print('[preflight][ERR] 面板缺少列 Y_t')
    else:
        vc = df['Y_t'].value_counts(dropna=False).to_dict()
        print('[preflight] Y_t value_counts:', vc)
        try:
            print('[preflight] Y_t pos_rate:', float(df['Y_t'].mean()))
        except Exception:
            pass

    # Missing / unique counts per available core column
    rows = []
    for name in core_columns:
        if name not in df.columns:
            continue
        non_null = df[name].notna().sum()
        rows.append({
            'col': name,
            'n_total': n_total,
            'non_null': non_null,
            'missing': n_total - non_null,
            'nunique': df[name].nunique(dropna=True),
        })
    rpt = pd.DataFrame(rows)
    try:
        rpt.to_csv(report_path, index=False, encoding='utf-8-sig')
        print('[preflight] 报告已写出：', report_path.resolve())
    except Exception as e:
        print('[preflight][warn] 写出报告失败：', e)

    # Cumulative row loss as each required column is added to the dropna set
    required_order = ('Y_t', 'province', 'year', 'sea_region', 'two_path_open_t1_ortho_z',
                      'logdeg_t1_z', 'spatial_lag_misfit_t1_z', 'corridor_betweenness_t1_z')
    d = df.copy()
    print('[preflight] 累积 dropna 影响（column → 剩余样本）:')
    remain = len(d)
    print(f'  起始: {remain}')
    for c in (name for name in required_order if name in df.columns):
        before = remain
        d = d.dropna(subset=[c])
        remain = len(d)
        print(f'  + 要求 {c} 非空: {before} → {remain}')

    if 'Y_t' not in d.columns:
        return
    try:
        print('[preflight] 在最小必需集下，Y_t nunique =', d['Y_t'].nunique(),
              ' pos_rate =', float(d['Y_t'].mean()))
    except Exception:
        pass



# ========= 主流程 =========
def _append_csv(frame, path):
    """Write ``frame`` to ``path``, appending when the file already exists.

    The header row is written only when the file is created, so repeated
    runs do not insert header lines in the middle of the data.
    """
    is_new = not path.exists()
    frame.to_csv(path, mode='w' if is_new else 'a', header=is_new,
                 index=False, encoding='utf-8-sig')


def _term_p_value(coefs, term):
    """Return the first p-value recorded for ``term``, or None if it is absent."""
    hits = coefs.loc[coefs['term'] == term, 'p_value']
    if hits.empty:
        return None
    try:
        return float(hits.iloc[0])
    except (TypeError, ValueError):
        return np.nan


def _write_predictions(d, res, path):
    """Export identifiers, observed outcome and predicted probability to ``path``.

    Only the identifier columns actually present in ``d`` are exported. If
    ``res.predict()`` raises, the sample base rate is written instead.
    """
    id_cols = [c for c in ('MPA_ID', 'year', 'sea_region', 'province') if c in d.columns]
    pred = d[id_cols].copy()
    y = d[Y_COL].to_numpy()
    try:
        pprob = res.predict()
    except Exception:
        pprob = np.repeat(y.mean(), len(y))
    pred['y_true'] = y
    pred['y_prob'] = pprob
    pred.to_csv(path, index=False, encoding='utf-8-sig')


def main():
    """Run the pooled ALAAM workflow end to end.

    1. Load the panel and print / write the preflight report.
    2. Build the orthogonalised two-path term (winsorize -> per-year
       regression on the base covariates -> residual -> per-year z-score).
    3. Fit the full model, drop optional covariates whose p-value is
       >= PRUNE_PVAL (or NaN), and refit on the pruned right-hand side.
    4. If the two-path term has p >= TP_SWITCH_PVAL in the pruned model,
       replace it with a same-year top-quartile dummy and refit.
    5. Write coefficients, GoF metrics and figure, predictions, and a short
       text record of the adaptive choice into OUT_DIR.
    """
    np.random.seed(RANDOM_SEED)
    df = load_panel(PANEL_PATH)
    preflight_panel_report(df)

    coef_path = OUT_DIR/'alaam_coefficients_pooled.csv'
    gof_path = OUT_DIR/'alaam_gof_pooled.png'
    pred_path = OUT_DIR/'alaam_predictions.csv'
    choice_path = OUT_DIR/'two_path_adaptive_choice.txt'

    # Spatial-lag top-Q4 dummy (disabled by configuration)
    if ADD_SPATIAL_TOPQ4 and 'spatial_lag_misfit_t1' in df.columns:
        df = add_topQ4_dummy(df, 'spatial_lag_misfit_t1', 'spatial_lag_topQ4_t1')

    # Orthogonalised two-path term
    if 'two_path_open_t1' in df.columns:
        df['two_path_open_t1_w'] = winsorize_by_year(df, 'two_path_open_t1', q=WINSOR_TWOPATH_Q)
        bases = [b for b in TP_BASES_FOR_ORTHO if b in df.columns]
        _, tp_z = residualize_by_year(df, target='two_path_open_t1_w', bases=bases,
                                      add_z=True, name_prefix='two_path_open_t1_ortho')
        df['two_path_open_t1_ortho_z'] = tp_z

    # Assemble the full right-hand side
    cont_cols = [c for c in BASE_CONT_Z if c in df.columns]
    use_two_path_term = 'two_path_open_t1_ortho_z' if 'two_path_open_t1_ortho_z' in df.columns else None
    if use_two_path_term:
        cont_cols.append(use_two_path_term)
    bin_cols = [c for c in BINARY_COLS if c in df.columns]
    rhs_full = (cont_cols + bin_cols
                + (['C(sea_region)'] if 'sea_region' in df.columns else [])
                + (['C(year)'] if 'year' in df.columns else []))
    formula_full = f"{Y_COL} ~ " + (' + '.join(rhs_full) if rhs_full else '1')

    # First full fit: only used to identify extremely insignificant terms
    d_full, rob_full, f_full = fit_logit(df, formula_full)
    if rob_full is None:
        rob_full = _constant_result(d_full)
    coefs_full = extract_coefs(rob_full, 'pooled_step1_full')
    _append_csv(coefs_full, coef_path)
    print('[info] Step1-FULL 公式：', f_full)

    # Pruning: only the optional covariates can be dropped; structural and
    # forced terms and the two-path term are always kept.
    protected = set(STRUCTURAL_MIN_KEEP) | set(FORCE_KEEP)
    if use_two_path_term:
        protected.add(use_two_path_term)
    drop_candidates = [c for c in ['z_prov_count_looi_t1_z','nonNat_share_t1_z','high_misfit_t1']
                       if c in rhs_full]
    to_drop = []
    for t in drop_candidates:
        p_val = _term_p_value(coefs_full, t)
        if p_val is None or t in protected:
            continue
        if np.isnan(p_val) or p_val >= PRUNE_PVAL:
            to_drop.append(t)

    if to_drop:
        rhs_pruned = [t for t in rhs_full if (t not in to_drop)]
        formula_step1 = f"{Y_COL} ~ " + (' + '.join(rhs_pruned) if rhs_pruned else '1')
        print('[info] 精简入模，剔除：', to_drop)
    else:
        rhs_pruned = rhs_full[:]
        formula_step1 = f_full

    # Step 1: fit on the pruned right-hand side
    d_used, rob, f_final = fit_logit(df, formula_step1)
    if rob is None:
        rob = _constant_result(d_used)
    coefs1 = extract_coefs(rob, 'pooled_step1')
    _append_csv(coefs1, coef_path)
    print('[info] Step1（精简）公式：', f_final)

    # Adaptive check of the two-path term, based on the pruned model
    p_tp = _term_p_value(coefs1, use_two_path_term) if use_two_path_term else None
    need_switch = p_tp is not None and p_tp >= TP_SWITCH_PVAL

    if need_switch:
        if use_two_path_term not in d_used.columns:
            print('[warn] two_path_open_t1_ortho_z 缺失，无法生成 Q4；跳过 Step2。')
            gof_plot(d_used, rob, gof_path, label='pooled')
        else:
            # Step 2: same-year top-quartile dummy replaces the linear term
            if 'year' in d_used.columns:
                thr = d_used.groupby('year')[use_two_path_term].transform(lambda s: s.quantile(0.75))
            else:
                thr = d_used[use_two_path_term].quantile(0.75)
            d_used['two_path_topQ4_ortho_t1'] = (d_used[use_two_path_term] >= thr).astype(int)

            rhs2 = [x for x in rhs_pruned if x != use_two_path_term] + ['two_path_topQ4_ortho_t1']
            formula2 = f"{Y_COL} ~ " + (' + '.join(rhs2) if rhs2 else '1')
            d_used2, rob2, f_final2 = fit_logit(d_used, formula2)
            if rob2 is None:
                rob2 = _constant_result(d_used2)
            coefs2 = extract_coefs(rob2, 'pooled_step2_topQ4')
            _append_csv(coefs2, coef_path)
            gof_plot(d_used2, rob2, gof_path, label='pooled_topQ4')
            _write_predictions(d_used2, rob2, pred_path)

            with open(choice_path,'w',encoding='utf-8') as f:
                f.write(f'Step1(full)→pruned；two-path 线性项 p={p_tp:.4f} >= {TP_SWITCH_PVAL} → Step2 使用 two_path_topQ4_ortho_t1\n')
                f.write(f'最终公式：{f_final2}\n')
    else:
        gof_plot(d_used, rob, gof_path, label='pooled')
        _write_predictions(d_used, rob, pred_path)
        with open(choice_path,'w',encoding='utf-8') as f:
            f.write(f'Step1(full)→pruned；保留 {use_two_path_term}（p={p_tp if p_tp is not None else "NA"} < {TP_SWITCH_PVAL}）\n')
            f.write(f'最终公式：{f_final}\n')

    print('[DONE] ALAAM pooled 完成：')
    print(' - alaam_coefficients_pooled.csv（含 step1_full/step1_pruned/step2_topQ4 视情况）')
    print(' - alaam_gof_pooled.png（2×2：ROC/Calibration/PR/Gains） & alaam_gof_metrics.csv（含 AUPRC/KS）')
    print(' - alaam_predictions.csv')
    print(' - two_path_adaptive_choice.txt（自适应阈值决策记录）')



# Run the pooled workflow only when executed as a script.
if __name__ == '__main__':
    main()


